/*
MIT License

Copyright (c) 2021 МГТУ им. Н.Э. Баумана, кафедра ИУ-6, Михаил Фетисов,

https://bmstu.codes/lsx/simodo
*/

#include "simodo/parser/Grammar.h"
#include "simodo/ast/NodeFriend.h"
#include "simodo/inout/convert/functions.h"

#include <fstream>
#include <cassert>

// C++17 and later ship <filesystem>; older toolchains provide the TS version.
// NOTE: the previous test compared against __cpp_2017, which is not a
// predefined macro (it evaluates to 0 inside #if, making the test always
// true), and the fallback aliased the nonexistent std::filesystem::experimental
// instead of std::experimental::filesystem.
#if __cplusplus >= 201703L
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

namespace simodo::parser
{

    /// Finds the column of the parse table that corresponds to the given lexeme.
    ///
    /// \attention This method is used before the terminal_symbol_index and
    /// compound_symbol_index lookup structures are filled!
    ///
    /// \return the column index, or columns.size() when the lexeme is unknown.
    size_t Grammar::getColumnIndex(const inout::Lexeme &lexeme) const
    {
        assert(first_compound_index > 0 && first_compound_index < columns.size());

        // Compound symbols occupy [first_compound_index, columns.size()),
        // every other lexeme kind lives in [0, first_compound_index).
        size_t begin = 0;
        size_t end   = first_compound_index;

        if (lexeme.type() == inout::LexemeType::Compound)
        {
            begin = first_compound_index;
            end   = columns.size();
        }

        for(size_t index = begin; index < end; ++index)
            if (lexeme == columns[index])
                return index;

        // Not found: one-past-the-end acts as the "no such column" sentinel.
        return columns.size();
    }

    /// Finds the terminal column matching the given lexeme.
    ///
    /// Id, Number, Comment and Annotation lexemes match a column by type
    /// alone (the concrete text is irrelevant for those categories); all
    /// other terminals must compare equal to the column lexeme.
    ///
    /// \return the column index, or columns.size() when no terminal matches.
    size_t Grammar::getTerminalColumnIndex(const inout::Lexeme &lexeme) const
    {
        assert(first_compound_index > 0 && first_compound_index < columns.size());

        // Categories where the lexeme type by itself identifies the column.
        auto matches_by_type_alone = [](inout::LexemeType t) {
            return t == inout::LexemeType::Id
                || t == inout::LexemeType::Number
                || t == inout::LexemeType::Comment
                || t == inout::LexemeType::Annotation;
        };

        for(size_t index = 0; index < first_compound_index; ++index)
        {
            if (lexeme.type() == columns[index].type() && matches_by_type_alone(lexeme.type()))
                return index;

            if (lexeme == columns[index])
                return index;
        }

        return columns.size();
    }

    /// Finds the compound-symbol column whose text equals \p str.
    ///
    /// Compound symbols occupy the tail of the column table, starting at
    /// first_compound_index.
    ///
    /// \return the column index, or columns.size() when not found.
    size_t Grammar::getCompoundColumnIndex(const std::u16string &str) const
    {
        assert(first_compound_index > 0 && first_compound_index < columns.size());

        size_t index = first_compound_index;

        while (index < columns.size())
        {
            if (str == columns[index].lexeme())
                return index;
            ++index;
        }

        // One-past-the-end sentinel: no such compound symbol.
        return columns.size();
    }

    /// Returns the single-letter mnemonic of an FSM action
    /// (E = error, S = shift, R = reduce, A = acceptance, * = unknown).
    std::u16string getFsmActionChar(FsmActionType action)
    {
        if (action == FsmActionType::Error)
            return u"E";
        if (action == FsmActionType::Shift)
            return u"S";
        if (action == FsmActionType::Reduce)
            return u"R";
        if (action == FsmActionType::Acceptance)
            return u"A";

        return u"*";
    }

    /// Returns the human-readable name of a parse-table construction method
    /// ("*****" for an unrecognized value). LALR is not supported yet.
    const char * getGrammarBuilderMethodName(TableBuildMethod method)
    {
        if (method == TableBuildMethod::none)
            return "none";
        if (method == TableBuildMethod::SLR)
            return "SLR";
        if (method == TableBuildMethod::LR1)
            return "LR1";

        return "*****";
    }

    /// Derives the lexer's punctuation settings from the grammar's terminal
    /// columns: single-character punctuation lexemes are concatenated into
    /// lexical.punctuation_chars, multi-character ones are appended to
    /// lexical.punctuation_words.
    bool fillLexemeParameters(parser::Grammar &g)
    {
        assert(!g.columns.empty());
        assert(g.first_compound_index < g.columns.size());

        std::u16string single_chars;

        // Only the terminal part of the column table is scanned.
        for(size_t index = 0; index < g.first_compound_index; ++index)
        {
            const inout::Lexeme & column = g.columns[index];

            if (column.type() != inout::LexemeType::Punctuation)
                continue;

            const std::u16string & text = column.lexeme();

            if (text.size() == 1)
                single_chars += text;
            else if (!text.empty())
                g.lexical.punctuation_words.emplace_back(text);
        }

        g.lexical.punctuation_chars = single_chars;

        return true;
    }

    namespace
    {
        // Signature written at the head of every dump file; it also encodes the
        // dump format version, so stale-format dumps are rejected on load.
        const std::u16string GRAMMAR_DUMP_SIGNATURE   = u"SIMODO.loom.grammar.dump.v03";
        const size_t         REASONABLE_STRING_LENGTH = 2000;
        // Hard upper bound (in bytes) for any single serialized string;
        // enforced by storeString/loadString below.
        const size_t         MAX_STRING_LENGTH        = REASONABLE_STRING_LENGTH * 2;

        // Dumps live in <grammar dir>/dump/<grammar stem>.simodo-grammar-dump
        const std::string    GRAMMAR_DUMP_DIRECTORY   = "dump";
        const std::string    GRAMMAR_DUMP_EXTENTION   = ".simodo-grammar-dump";

        /// Writes a UTF-16 string as <byte length : size_t><raw code units>.
        bool storeString(std::ostream &os, const std::u16string &s)
        {
            const size_t byte_count = s.size() * sizeof(s[0]);

            assert(byte_count < MAX_STRING_LENGTH);

            os.write(reinterpret_cast<const char *>(&byte_count), sizeof(byte_count));

            if (!s.empty())
                os.write(reinterpret_cast<const char *>(s.data()), static_cast<std::streamsize>(byte_count));

            return true;
        }

        /// Reads a string previously written by storeString():
        /// <byte length : size_t><raw UTF-16 code units>.
        ///
        /// \return the decoded string; an empty string (with the stream's
        ///         failbit set) when the stored length is implausible.
        std::u16string loadString(std::istream &is)
        {
            size_t  len = 0;
            is.read(reinterpret_cast<char *>(&len), sizeof(len));

            assert(len <= MAX_STRING_LENGTH);

            // assert() disappears in release builds, so a corrupted dump with a
            // huge length would overflow the stack buffer below — enforce the
            // bound at run time as well.
            if (!is || len > MAX_STRING_LENGTH)
            {
                is.setstate(std::ios_base::failbit);
                return {};
            }

            char16_t b[MAX_STRING_LENGTH+1];

            if (len > 0)
                is.read(reinterpret_cast<char *>(b), static_cast<std::streamsize>(len));

            // Construct with an explicit length so embedded U+0000 code units
            // survive the round trip (the previous null-terminated constructor
            // truncated at the first zero).
            return std::u16string(b, len/sizeof(b[0]));
        }

        /// Serializes a lexeme as <text (storeString)><LexemeType raw bytes>.
        bool storeLexeme(std::ostream & os, const inout::Lexeme & s)
        {
            storeString(os, s.lexeme());

            const inout::LexemeType type = s.type();
            os.write(reinterpret_cast<const char *>(&type), sizeof(type));

            return true;
        }

        /// Deserializes a lexeme written by storeLexeme() (same field order).
        inout::Lexeme loadLexeme(std::istream &is)
        {
            std::u16string    value(loadString(is));
            inout::LexemeType type;

            is.read(reinterpret_cast<char *>(&type), sizeof(type));

            // Pass the string object itself: going through value.c_str() would
            // truncate the lexeme text at an embedded U+0000 code unit.
            return {value, type};
        }

        /// Serializes a token: lexeme part, token text, location
        /// (range + uri index) and qualification — in that order.
        bool storeToken(std::ostream &os, const inout::Token &t)
        {
            storeLexeme(os, t);
            storeString(os, t.token());

            const inout::Range              range         = t.location().range();
            const inout::uri_index_t        uri_index     = t.location().uri_index();
            const inout::TokenQualification qualification = t.qualification();

            os.write(reinterpret_cast<const char *>(&range), sizeof(range));
            os.write(reinterpret_cast<const char *>(&uri_index), sizeof(uri_index));
            os.write(reinterpret_cast<const char *>(&qualification), sizeof(qualification));

            return true;
        }

        /// Deserializes a token written by storeToken() (same field order).
        inout::Token loadToken(std::istream & is)
        {
            const inout::Lexeme  lexeme       = loadLexeme(is);
            const std::u16string token_string = loadString(is);

            inout::Range       range;
            inout::uri_index_t uri_index;
            is.read(reinterpret_cast<char *>(&range), sizeof(range));
            is.read(reinterpret_cast<char *>(&uri_index), sizeof(uri_index));

            inout::TokenQualification qualification;
            is.read(reinterpret_cast<char *>(&qualification), sizeof(qualification));

            return { lexeme, token_string, inout::TokenLocation(uri_index,range), qualification };
        }

        /// Recursively serializes an AST node: host name, operation code,
        /// token, bound token, then the branch count followed by each branch.
        bool storeAstNode(std::ostream & os, const ast::Node & node)
        {
            storeString(os, node.host());

            const ast::OperationCode op = node.operation();
            os.write(reinterpret_cast<const char *>(&op), sizeof(op));

            storeToken(os, node.token());
            storeToken(os, node.bound());

            const size_t branch_count = node.branches().size();
            os.write(reinterpret_cast<const char *>(&branch_count), sizeof(branch_count));

            for(const ast::Node & branch : node.branches())
                storeAstNode(os, branch);

            return true;
        }

        /// Recursively deserializes an AST node written by storeAstNode()
        /// (same field order; branches are appended via NodeFriend).
        ast::Node loadAstNode(std::istream & is)
        {
            const std::u16string host_name = loadString(is);

            ast::OperationCode op;
            is.read(reinterpret_cast<char *>(&op), sizeof(op));

            const inout::Token symbol = loadToken(is);
            const inout::Token bound  = loadToken(is);

            ast::Node node(host_name, op, symbol, bound);

            size_t branch_count = 0;
            is.read(reinterpret_cast<char *>(&branch_count), sizeof(branch_count));

            auto & branches = ast::NodeFriend::branches(node);
            branches.reserve(branch_count);
            for(size_t i=0; i < branch_count; ++i)
                branches.emplace_back(loadAstNode(is));

            return node;
        }

        /// Serializes a markup symbol: start, end, ignore sign, lexeme type.
        bool storeMarkup(std::ostream & os, const inout::MarkupSymbol & ms)
        {
            storeString(os,ms.start);
            storeString(os,ms.end);
            storeString(os,ms.ignore_sign);

            // Const-correct cast, consistent with the other store* helpers
            // (the previous non-const reinterpret_cast was unnecessary).
            const inout::LexemeType  type = ms.type;
            os.write(reinterpret_cast<const char *>(&type), sizeof(type));

            return true;
        }

        /// Deserializes a markup symbol written by storeMarkup() (same order).
        inout::MarkupSymbol loadMarkup(std::istream & is)
        {
            const std::u16string start       = loadString(is);
            const std::u16string end         = loadString(is);
            const std::u16string ignore_sign = loadString(is);

            inout::LexemeType type;
            is.read(reinterpret_cast<char *>(&type), sizeof(type));

            return inout::MarkupSymbol{start, end, ignore_sign, type};
        }

        /// Serializes a number mask: chars, lexeme type, number system.
        bool storeMask(std::ostream & os, const inout::NumberMask & mask)
        {
            storeString(os,mask.chars);

            const inout::LexemeType      type   = mask.type;
            const inout::number_system_t system = mask.system;

            os.write(reinterpret_cast<const char *>(&type), sizeof(type));
            os.write(reinterpret_cast<const char *>(&system), sizeof(system));

            return true;
        }

        /// Deserializes a number mask written by storeMask() (same order).
        inout::NumberMask loadMask(std::istream &is)
        {
            const std::u16string chars = loadString(is);

            inout::LexemeType type;
            is.read(reinterpret_cast<char *>(&type), sizeof(type));

            inout::number_system_t system;
            is.read(reinterpret_cast<char *>(&system), sizeof(system));

            return inout::NumberMask{chars, type, system};
        }

    }

    /// Writes a binary dump of the grammar to
    /// <grammar dir>/dump/<grammar stem>.simodo-grammar-dump.
    ///
    /// On-disk layout (counters are raw size_t, enums/PODs are raw bytes,
    /// strings go through storeString): signature, build_method, lexical
    /// settings, files, handlers, rules, columns, first_compound_index,
    /// states, parse_table, key/value bounds, and a trailing "end" marker
    /// that loadGrammarDump checks for completeness. The statement order
    /// here IS the file format — keep it in sync with loadGrammarDump.
    ///
    /// \return false when the dump directory cannot be created or the file
    ///         cannot be opened; true otherwise.
    bool saveGrammarDump(const std::string & grammar_file, const Grammar & g)
    {
        fs::path grammar_dump_path { fs::path(grammar_file).parent_path() / GRAMMAR_DUMP_DIRECTORY };

        if (!fs::exists(grammar_dump_path) && !fs::create_directory(grammar_dump_path)) 
            return false;

        grammar_dump_path /= fs::path(grammar_file).stem().string() + GRAMMAR_DUMP_EXTENTION;

        std::ofstream os(grammar_dump_path, std::ios::binary);

        if (!os)
            return false;

        // LSX_GRAMMAR_DUMP_SIGNATURE — identifies file kind and format version
        storeString(os,GRAMMAR_DUMP_SIGNATURE);
        // build_method
        os.write(reinterpret_cast<const char *>(&g.build_method), sizeof(g.build_method));
        // lexical
        size_t markups_count = g.lexical.markups.size();
        os.write(reinterpret_cast<char *>(&markups_count), sizeof(markups_count));
        for(const inout::MarkupSymbol & ms : g.lexical.markups)
            storeMarkup(os, ms);
        size_t masks_count = g.lexical.masks.size();
        os.write(reinterpret_cast<char *>(&masks_count), sizeof(masks_count));
        for(const inout::NumberMask & mask : g.lexical.masks)
            storeMask(os, mask);
        storeString(os, g.lexical.national_alphabet);
        storeString(os, g.lexical.id_extra_symbols);
        // boolean flags are widened to uint16_t for a stable on-disk size
        uint16_t may_national_letters_use = g.lexical.may_national_letters_use;
        os.write(reinterpret_cast<char *>(&may_national_letters_use), sizeof(may_national_letters_use));
        uint16_t may_national_letters_mix = g.lexical.may_national_letters_mix;
        os.write(reinterpret_cast<char *>(&may_national_letters_mix), sizeof(may_national_letters_mix));
        uint16_t is_case_sensitive = g.lexical.is_case_sensitive;
        os.write(reinterpret_cast<char *>(&is_case_sensitive), sizeof(is_case_sensitive));
        storeString(os, g.lexical.nl_substitution);
        // files (paths are converted UTF-8 -> UTF-16 for storage)
        size_t files_count = g.files.size();
        os.write(reinterpret_cast<char *>(&files_count), sizeof(files_count));
        for(const std::string & file_path : g.files)
            storeString(os, inout::toU16(file_path));
        // handlers: (name, AST) pairs
        size_t handlers_count = g.handlers.size();
        os.write(reinterpret_cast<char *>(&handlers_count), sizeof(handlers_count));
        for(const auto & [handler_name,handler_ast] : g.handlers)
        {
            storeString(os, handler_name);
            storeAstNode(os, handler_ast);
        }
        // rules: production, pattern lexemes, reduce action AST, direction
        size_t rulesCount = g.rules.size();
        os.write(reinterpret_cast<char *>(&rulesCount), sizeof(rulesCount));
        for(const GrammarRule & r : g.rules)
        {
            storeString(os, r.production);
            size_t patternCount = r.pattern.size();
            os.write(reinterpret_cast<char *>(&patternCount), sizeof(patternCount));
            for(const inout::Lexeme & s:r.pattern)
                storeLexeme(os, s);

            storeAstNode(os, r.reduce_action);

            RuleReduceDirection direction = r.reduce_direction;
            os.write(reinterpret_cast<char *>(&direction), sizeof(direction));
        }
        // columns
        size_t columnsCount = g.columns.size();
        os.write(reinterpret_cast<char *>(&columnsCount), sizeof(columnsCount));
        for(const inout::Lexeme & si : g.columns)
            storeLexeme(os, si);
        // first_compound_index
        os.write(reinterpret_cast<const char *>(&g.first_compound_index), sizeof(g.first_compound_index));
        // states: per state, the positions with their lookahead sets
        size_t linesCount = g.states.size();
        os.write(reinterpret_cast<char *>(&linesCount), sizeof(linesCount));
        for(const FsmState_t & state : g.states)
        {
            size_t positionsCount = state.size();
            os.write(reinterpret_cast<char *>(&positionsCount), sizeof(positionsCount));
            for(const FsmStatePosition & p : state)
            {
                // fields are widened to size_t/int locals for a fixed layout
                size_t main = p.is_main ? 1 : 0;
                size_t rno = p.rule_no;
                size_t pos = p.position;

                int next = p.next_state_no;

                os.write(reinterpret_cast<char *>(&rno), sizeof(rno));
                os.write(reinterpret_cast<char *>(&pos), sizeof(pos));
                os.write(reinterpret_cast<char *>(&next), sizeof(next));
                os.write(reinterpret_cast<char *>(&main), sizeof(main));

                // lookahead

                size_t lookahead_count = p.lookahead.size();

                os.write(reinterpret_cast<char *>(&lookahead_count), sizeof(lookahead_count));
                for(const inout::Lexeme & lex : p.lookahead)
                    storeLexeme(os, lex);
            }
        }
        // parse_table
        size_t FSMtableCount = g.parse_table.size();
        os.write(reinterpret_cast<char *>(&FSMtableCount), sizeof(FSMtableCount));
        // by-value bindings: the copies give addressable non-const locals
        for (auto [key,value] : g.parse_table)
        {
            os.write(reinterpret_cast<const char *>(&key), sizeof(key));
            os.write(reinterpret_cast<char *>(&value), sizeof(value));
        }
        // key_bound
        os.write(reinterpret_cast<const char *>(&g.key_bound), sizeof(g.key_bound));
        // value_bound
        os.write(reinterpret_cast<const char *>(&g.value_bound), sizeof(g.value_bound));
        // "end" — completeness marker checked by loadGrammarDump
        storeString(os,u"end");

        os.close();
        return true;
    }

    /// Restores a grammar from the binary dump produced by saveGrammarDump()
    /// (<grammar dir>/dump/<grammar stem>.simodo-grammar-dump, read in the
    /// same field order).
    ///
    /// \param grammar_file path of the grammar source the dump belongs to
    /// \param g            grammar to fill
    /// \return true only when the dump exists, carries the expected signature
    ///         and terminates with the "end" marker.
    bool loadGrammarDump(const std::string & grammar_file, Grammar & g)
    {
        fs::path grammar_dump_path  {   fs::path(grammar_file).parent_path() 
                                        / GRAMMAR_DUMP_DIRECTORY 
                                        / (fs::path(grammar_file).stem().string() + GRAMMAR_DUMP_EXTENTION)
                                    };

        std::ifstream is(grammar_dump_path, std::ios::binary);

        if (!is)
            return false;

        // GRAMMAR_DUMP_SIGNATURE: bail out before touching `g` on a foreign or
        // stale-version dump. (Previously a signature mismatch was only noted
        // and then unconditionally overwritten by the trailing "end" check, so
        // a wrong-version dump could still be reported as loaded successfully.)
        if (loadString(is) != GRAMMAR_DUMP_SIGNATURE)
            return false;

        // build_method
        is.read(reinterpret_cast<char *>(&g.build_method), sizeof(g.build_method));

        // lexical
        size_t markups_count;
        is.read(reinterpret_cast<char *>(&markups_count), sizeof(markups_count));
        g.lexical.markups.reserve(markups_count);
        for(size_t i=0; i < markups_count; ++i)
        {
            inout::MarkupSymbol ms = loadMarkup(is);

            g.lexical.markups.push_back(ms);
        }
        size_t masks_count;
        is.read(reinterpret_cast<char *>(&masks_count), sizeof(masks_count));
        g.lexical.masks.reserve(masks_count);
        for(size_t i=0; i < masks_count; ++i)
        {
            inout::NumberMask mask = loadMask(is);

            g.lexical.masks.push_back(mask);
        }
        g.lexical.national_alphabet = loadString(is);
        g.lexical.id_extra_symbols = loadString(is);
        // boolean flags are stored as uint16_t — read and normalize
        uint16_t may_national_letters_use;
        is.read(reinterpret_cast<char *>(&may_national_letters_use), sizeof(may_national_letters_use));
        g.lexical.may_national_letters_use = (may_national_letters_use != 0);
        uint16_t may_national_letters_mix;
        is.read(reinterpret_cast<char *>(&may_national_letters_mix), sizeof(may_national_letters_mix));
        g.lexical.may_national_letters_mix = (may_national_letters_mix != 0);
        uint16_t is_case_sensitive;
        is.read(reinterpret_cast<char *>(&is_case_sensitive), sizeof(is_case_sensitive));
        g.lexical.is_case_sensitive = (is_case_sensitive != 0);
        g.lexical.nl_substitution = loadString(is);

        // files (stored as UTF-16, kept as UTF-8 in memory)
        size_t files_count;
        is.read(reinterpret_cast<char *>(&files_count), sizeof(files_count));
        g.files.reserve(files_count);
        for(size_t i=0; i < files_count; ++i)
        {
            std::string file_path = inout::toU8(loadString(is));

            g.files.push_back(file_path);
        }

        // handlers: (name, AST) pairs
        size_t handlers_count;
        is.read(reinterpret_cast<char *>(&handlers_count), sizeof(handlers_count));
        for(size_t i_handler=0; i_handler < handlers_count; ++i_handler)
        {
            std::u16string handler_name = loadString(is);
            ast::Node      handler_ast  = loadAstNode(is);

            g.handlers.insert({handler_name, handler_ast});
        }

        // rules: production, pattern lexemes, reduce action AST, direction
        size_t rulesCount;
        is.read(reinterpret_cast<char *>(&rulesCount), sizeof(rulesCount));
        g.rules.reserve(rulesCount);
        for(size_t iRule=0; iRule < rulesCount; ++iRule)
        {
            std::u16string production = loadString(is);

            size_t patternCount;
            is.read(reinterpret_cast<char *>(&patternCount), sizeof(patternCount));

            std::vector<inout::Lexeme> pattern;
            pattern.reserve(patternCount);
            for(size_t iPat=0; iPat < patternCount; ++iPat)
                pattern.emplace_back(loadLexeme(is));

            ast::Node node = loadAstNode(is);

            RuleReduceDirection direction;
            is.read(reinterpret_cast<char *>(&direction), sizeof(direction));

            g.rules.emplace_back(production, pattern, node, direction);
        }

        // columns
        size_t columnsCount;
        is.read(reinterpret_cast<char *>(&columnsCount), sizeof(columnsCount));
        g.columns.reserve(columnsCount);
        for(size_t iCol=0; iCol < columnsCount; ++iCol)
            g.columns.emplace_back(loadLexeme(is));

        // first_compound_index
        is.read(reinterpret_cast<char *>(&g.first_compound_index), sizeof(g.first_compound_index));

        // states: per state, the positions with their lookahead sets
        size_t linesCount;
        is.read(reinterpret_cast<char *>(&linesCount), sizeof(linesCount));
        g.states.reserve(linesCount);
        for(size_t iLine=0; iLine < linesCount; ++iLine)
        {
            FsmState_t line;

            size_t positionsCount;
            is.read(reinterpret_cast<char *>(&positionsCount), sizeof(positionsCount));
            line.reserve(positionsCount);
            for(size_t iPos=0; iPos < positionsCount; ++iPos)
            {
                size_t rno;
                size_t pos;
                int    next;
                size_t main;

                std::set<inout::Lexeme> lookahead;

                is.read(reinterpret_cast<char *>(&rno), sizeof(rno));
                is.read(reinterpret_cast<char *>(&pos), sizeof(pos));
                is.read(reinterpret_cast<char *>(&next), sizeof(next));
                is.read(reinterpret_cast<char *>(&main), sizeof(main));
                // lookahead
                size_t lookahead_count;
                is.read(reinterpret_cast<char *>(&lookahead_count), sizeof(lookahead_count));
                for(size_t i=0; i < lookahead_count; ++i)
                {
                    inout::Lexeme lex = loadLexeme(is);
                    lookahead.insert(lex);
                }
                line.emplace_back(rno,pos,lookahead,main,next);
            }
            g.states.emplace_back(line);
        }

        // parse_table
        size_t FSMtableCount;
        is.read(reinterpret_cast<char *>(&FSMtableCount), sizeof(FSMtableCount));
        for(size_t iCount=0; iCount < FSMtableCount; ++iCount)
        {
            Fsm_key_t   key;
            Fsm_value_t value;

            is.read(reinterpret_cast<char *>(&key), sizeof(key));
            is.read(reinterpret_cast<char *>(&value), sizeof(value));

            g.parse_table.emplace(key,value);
        }

        // key_bound
        is.read(reinterpret_cast<char *>(&g.key_bound), sizeof(g.key_bound));

        // value_bound
        is.read(reinterpret_cast<char *>(&g.value_bound), sizeof(g.value_bound));

        // "end": the dump must close with the literal marker written by
        // saveGrammarDump(); anything else means a truncated/corrupted file.
        const bool success = (loadString(is) == u"end");

        is.close();

        // Derive the lexer's punctuation settings from the restored columns.
        fillLexemeParameters(g);

        return success;
    }

}